{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a3beb2db",
   "metadata": {},
   "source": [
    "# 10.2 NGSolve with TensorFlow\n",
    "\n",
    "M. Feischl + J. Schöberl "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9d0e34c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ngsolve import *\n",
    "from ngsolve.webgui import Draw\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2544b4f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mesh of the unit square with maximal mesh size 0.2\n",
    "mesh = Mesh(unit_square.GenerateMesh(maxh=0.2))\n",
    "\n",
    "# Second-order H1 space; the pattern \".*\" marks every boundary as Dirichlet\n",
    "fes = H1(mesh, order=2, dirichlet=\".*\")\n",
    "u,v = fes.TnT()\n",
    "# Forms of the weak problem: (grad u, grad v) = (1, v)\n",
    "a = BilinearForm(grad(u)*grad(v)*dx)\n",
    "f = LinearForm(v*dx)\n",
    "gfu = GridFunction(fes)\n",
    "\n",
    "# First-order vector field, used by Solve() as the mesh deformation\n",
    "deform = GridFunction(VectorH1(mesh, order=1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a9c52a2",
   "metadata": {},
   "source": [
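    "Solve a parametric problem: for parameters $(A_x, A_y)$ the unit square is deformed by the displacement field $(A_x x y, A_y x y)$, and the Poisson problem $-\\Delta u = 1$ with homogeneous Dirichlet boundary conditions is solved on the deformed mesh."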
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a300fa9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def Solve(Ax, Ay):\n",
    "    \"\"\"Solve the Poisson problem on the mesh deformed by (x*y*Ax, x*y*Ay).\n",
    "\n",
    "    The solution is written into the global GridFunction ``gfu``; the mesh\n",
    "    keeps the new deformation after the call.\n",
    "    \"\"\"\n",
    "    # start from the undeformed mesh before computing the new deformation\n",
    "    mesh.UnsetDeformation()\n",
    "    shear = x*y\n",
    "    deform.Interpolate((shear*Ax, shear*Ay))\n",
    "    mesh.SetDeformation(deform)\n",
    "\n",
    "    # re-assemble the matrix and the right-hand side after changing the deformation\n",
    "    for form in (a, f):\n",
    "        form.Assemble()\n",
    "\n",
    "    inverse = a.mat.Inverse(fes.FreeDofs())\n",
    "    gfu.vec.data = inverse * f.vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2486bba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# solve for one sample parameter pair and plot the solution\n",
    "Solve(0.8, 0.5)\n",
    "Draw(gfu);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c4402a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# np.asarray (gfu.vec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68f0e703",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_data = 50   # number of datapoints\n",
    "input_dim = 2   # dimension of each datapoint: the parameters (Ax, Ay)\n",
    "\n",
    "# seeded generator, so that the training inputs are the same on every run\n",
    "rng = np.random.default_rng(seed=0)\n",
    "data_in = rng.uniform(0, 1, size=(n_data, input_dim))  # artificial datapoints in [0,1)^2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2af79d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dim = fes.ndof   # one network output per finite element dof\n",
    "data_out = np.zeros((n_data, output_dim))\n",
    "\n",
    "# one finite element solve per sample; row i stores the coefficient vector\n",
    "for i, params in enumerate(data_in):\n",
    "    Solve(*params)\n",
    "    data_out[i] = gfu.vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3416c27b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# size of the network output (number of dofs of fes)\n",
    "print(output_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e3602f7",
   "metadata": {},
   "source": [
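    "Train a small fully connected network that maps the parameters $(A_x, A_y)$ to the finite element coefficient vector of the solution."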
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4f382e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# Seed Python, NumPy and TensorFlow so that weight initialisation and\n",
    "# data shuffling are repeatable between runs\n",
    "tf.keras.utils.set_random_seed(0)\n",
    "\n",
    "# Fully connected network: (Ax, Ay) -> coefficient vector of the FE solution\n",
    "func = 'swish'   # activation function ('relu' is a possible alternative)\n",
    "model = tf.keras.models.Sequential([\n",
    "    tf.keras.layers.Dense(10,input_shape=(input_dim,),activation=func),\n",
    "    tf.keras.layers.Dense(20,activation=func),\n",
    "    tf.keras.layers.Dense(50,activation=func),\n",
    "    tf.keras.layers.Dense(output_dim)   # no activation on the output layer\n",
    "    ])\n",
    "\n",
    "# Standard Adam optimizer\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)\n",
    "\n",
    "# Loss function (error functional): mean squared error (least squares),\n",
    "# i.e. the mean of (y_true - y_pred)**2\n",
    "loss_fn = tf.keras.losses.MeanSquaredError()\n",
    "\n",
    "# Training\n",
    "epochs = 1000   # number of times the data is used for training\n",
    "batch_size = 1024   # upper bound per gradient step; larger than n_data, so every step uses the whole dataset\n",
    "model.compile(optimizer=optimizer,loss=loss_fn)\n",
    "model.fit(data_in,data_out,epochs=epochs,batch_size=batch_size,verbose=0)\n",
    "\n",
    "# final loss on the training data\n",
    "print(loss_fn(data_out,model(data_in)).numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1332a2c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# parameter pair used for both the network prediction and the reference solve\n",
    "test_Ax, test_Ay = 0.1, 0.2\n",
    "\n",
    "output = model(np.array([[test_Ax, test_Ay]]))\n",
    "outputvec = output.numpy().flatten()\n",
    "\n",
    "# reference finite element solution (this also sets the mesh deformation)\n",
    "Solve (test_Ax, test_Ay)\n",
    "Draw (gfu)\n",
    "\n",
    "# network prediction as a GridFunction on the same space\n",
    "gfumodel = GridFunction(fes)\n",
    "gfumodel.vec.FV()[:] = outputvec\n",
    "\n",
    "Draw (gfumodel)\n",
    "Draw (gfu-gfumodel, mesh);\n",
    "# L2 norm of the difference between reference and prediction\n",
    "print (\"err = \", Integrate((gfu-gfumodel)**2, mesh)**0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d12c1dd0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2734f36d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
